# =============================================================================
#  Analysis 2: This example script extends the main protocol analysis to videos of two differently coloured cell populations
# =============================================================================

# =============================================================================
# =============================================================================
# #  This repeats the preliminary analysis of the main protocol
# =============================================================================
# =============================================================================
"""
1. Read the Video
"""
from MOSES.Utility_Functions.file_io import read_multiimg_PIL

infile = '../Data/Videos/c_EPC2(G)_CP-A(R)_KSFM+RPMI_5_Fast GFP.tif_R-G.tif'
vidstack = read_multiimg_PIL(infile)

# Check the read size was expected. 
n_frame, n_rows, n_cols, n_channels = vidstack.shape
print('Size of video: (%d,%d,%d,%d)' %(n_frame,n_rows,n_cols,n_channels))


"""
2. Motion Extraction Parameters
"""
optical_flow_params = dict(pyr_scale=0.5, levels=3, winsize=15, iterations=3, poly_n=5, poly_sigma=1.2, flags=0)
# number of superpixels
n_spixels = 1000

"""
3. Extract the superpixel tracks 
"""
from MOSES.Optical_Flow_Tracking.superpixel_track import compute_grayscale_vid_superpixel_tracks

# extract superpixel tracks for the 1st or 'red' channel
optflow_r, meantracks_r = compute_grayscale_vid_superpixel_tracks(vidstack[:,:,:,0], optical_flow_params, n_spixels)
# extract superpixel tracks for the 2nd or 'green' channel
optflow_g, meantracks_g = compute_grayscale_vid_superpixel_tracks(vidstack[:,:,:,1], optical_flow_params, n_spixels)

"""
4. Save the superpixel tracks
"""
import scipy.io as spio
import os 
fname = os.path.split(infile)[-1]
savetracksmat = ('meantracks_'+ fname.replace('tif','.mat'))
spio.savemat(savetracksmat, {'meantracks_r':meantracks_r,
                             'meantracks_g':meantracks_g})

"""
4b. Loading the superpixel tracks
"""
meantracks_r = spio.loadmat(savetracksmat)['meantracks_r']
meantracks_g = spio.loadmat(savetracksmat)['meantracks_g']

# =============================================================================
# =============================================================================
# #  From here on is new analysis
# =============================================================================
# =============================================================================
"""
2. Video kymograph
"""
from MOSES.Visualisation_Tools.kymographs import kymograph_img
import numpy as np 
import pylab as plt 

vid_max_slice = kymograph_img(vidstack, axis=1, proj_fn=np.max)

fig, ax = plt.subplots()
ax.imshow(vid_max_slice)
ax.set_aspect('auto')
#fig.savefig('video_kymograph_example.svg', bbox_inches='tight')
plt.show()


"""
3. Velocity kymograph
"""
import scipy.io as spio

# # load the previously saved optical flow 
# save_optflow_mat = ('optflow_'+fname).replace('.tif', '.mat')
# optflow_r = spio.loadmat(save_optflow_mat)['optflow_r']
# optflow_g = spio.loadmat(save_optflow_mat)['optflow_g']

optflow = optflow_r + optflow_g
optflow_x = optflow[...,0]
# kymograph 1: using the dense optical flow. 
optflow_median_slice = kymograph_img(optflow_x, axis=1, proj_fn=np.nanmedian)


from MOSES.Visualisation_Tools.kymographs import construct_spatial_time_MOSES_velocity_x

# Reload the fixed-number superpixel tracks saved earlier (one read of the .mat file).
savetracksmat = ('meantracks_'+fname).replace('.tif', '.mat')
_tracks_mat = spio.loadmat(savetracksmat)
meantracks_r, meantracks_g = _tracks_mat['meantracks_r'], _tracks_mat['meantracks_g']

# Kymograph 2 (fixed superpixel tracks): per-channel x-velocity kymographs on a
# (n_frame, n_cols) grid, summed across the two channels.
velocity_kymograph_x_tracks_r = construct_spatial_time_MOSES_velocity_x(
    meantracks_r, shape=(n_frame, n_cols), n_samples=51, axis=1)
velocity_kymograph_x_tracks_g = construct_spatial_time_MOSES_velocity_x(
    meantracks_g, shape=(n_frame, n_cols), n_samples=51, axis=1)
velocity_kymograph_x_tracks = velocity_kymograph_x_tracks_r + velocity_kymograph_x_tracks_g


# Dense superpixel tracking (dense=True, mindensity=1) of each channel, saved to its own file.
print('computing dense tracks')
_, meantracks_r_dense = compute_grayscale_vid_superpixel_tracks(
    vidstack[..., 0], optical_flow_params, n_spixels, dense=True, mindensity=1)
_, meantracks_g_dense = compute_grayscale_vid_superpixel_tracks(
    vidstack[..., 1], optical_flow_params, n_spixels, dense=True, mindensity=1)

# NOTE: this rebinds savetracksmat to the dense-tracks filename.
savetracksmat = ('meantracks-dense_'+fname).replace('.tif', '.mat')
spio.savemat(savetracksmat, dict(meantracks_r=meantracks_r_dense,
                                 meantracks_g=meantracks_g_dense))

# Kymograph 3 (dense superpixel tracks): same construction as kymograph 2.
velocity_kymograph_x_tracks_r_dense = construct_spatial_time_MOSES_velocity_x(
    meantracks_r_dense, shape=(n_frame, n_cols), n_samples=51, axis=1)
velocity_kymograph_x_tracks_g_dense = construct_spatial_time_MOSES_velocity_x(
    meantracks_g_dense, shape=(n_frame, n_cols), n_samples=51, axis=1)
velocity_kymograph_x_tracks_dense = (velocity_kymograph_x_tracks_r_dense
                                     + velocity_kymograph_x_tracks_g_dense)

# Co-plot the three velocity kymographs side by side with the same diverging colour scale.
# Titles take the superpixel count from n_spixels so they stay correct if it is changed.
kymograph_panels = [
    ('Motion Field', optflow_median_slice),
    ('%d Superpixels (fixed)' % n_spixels, velocity_kymograph_x_tracks),
    ('%d Superpixels (dynamic)' % n_spixels, velocity_kymograph_x_tracks_dense),
]
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15,5))
for panel_ax, (panel_title, panel_img) in zip(ax, kymograph_panels):
    panel_ax.set_title(panel_title)
    panel_ax.imshow(panel_img, cmap='RdBu_r', vmin=-10, vmax=10)
    panel_ax.axis('off')
    panel_ax.set_aspect('auto')
plt.show()


# =============================================================================
# Plot each velocity kymograph on its own figure and save it as SVG.
# =============================================================================
for kymo_img, kymo_svg in [
        (optflow_median_slice, 'velocity_kymograph_optflow.svg'),
        (velocity_kymograph_x_tracks, 'velocity_kymograph_spixel_tracks(fixed).svg'),
        (velocity_kymograph_x_tracks_dense, 'velocity_kymograph_spixel_tracks(dense).svg')]:
    fig, ax = plt.subplots(figsize=(5,5))
    ax.imshow(kymo_img, cmap='RdBu_r', vmin=-10, vmax=10, aspect='auto')
    fig.savefig(kymo_svg, bbox_inches='tight')
    plt.show()

"""
plot a colorbar so we can cut it out to add in the images
"""
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
from mpl_toolkits.axes_grid1.colorbar import colorbar
fig, ax = plt.subplots(figsize=(5,5))
Q = ax.imshow(velocity_kymograph_x_tracks_dense, cmap='RdBu_r', vmin=-10, vmax=10)
ax.set_aspect('auto')
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05) 
plt.colorbar(Q, cax=cax)

fig.savefig('velocity_kymograph_spixel_tracks(dense)_w_colorbar.svg', bbox_inches='tight')
plt.show()



"""
4. Temporal tracking of the boundary using the superpixel tracks.
"""
from MOSES.Motion_Analysis.wound_statistics_tools import boundary_superpixel_meantracks_RGB

#boundary_curves, curves_lines, curve_img, boundary_line = (rgb_video, meantracks_r, meantracks_g, movement_thresh=0.2, t_av_motion=3, robust=False, lenient=True, debug_visual=False, max_dist=1.5, y_bins=50, y_frac_thresh=0.60):
boundary_curves, curves_lines, curve_img, boundary_line = boundary_superpixel_meantracks_RGB(vidstack, meantracks_r, meantracks_g, 
                                                                                             movement_thresh=0.2, 
                                                                                             t_av_motion=5, 
                                                                                             robust=False, lenient=False, 
                                                                                             debug_visual=False, 
                                                                                             max_dist=1.5, y_bins=50, 
                                                                                             y_frac_thresh=0.50)
# =============================================================================
# Plots boundary line onto frame.
# =============================================================================
# mark line onto the video frame.
from skimage.morphology import binary_dilation, square
# Binarise the final-frame boundary curve and thicken it with a 15x15 square so it is visible.
wound_frame_img = curve_img[-1] > 0
wound_frame_img = binary_dilation(wound_frame_img, square(15))
# Replicate the 2D mask across every channel of the frame (not a hard-coded 3).
wound_frame_img = np.dstack([wound_frame_img] * vidstack.shape[-1])

fig, ax = plt.subplots()
# Copy: vidstack[-1] is a NumPy view, so writing into it would permanently overwrite the
# last frame of the video, which the gap-closure segmentation below still reads.
image_and_wound = vidstack[-1].copy()
image_and_wound[wound_frame_img > 0] = 255
ax.imshow(image_and_wound)
plt.show()
# save the line marked frame
from skimage.io import imsave
#imsave('Frame_144_example_w_wound-line.tif', image_and_wound)

# Overlay the detected boundary line (thick black) on the raw-video max kymograph.
fig, ax = plt.subplots()
ax.imshow(vid_max_slice, aspect='auto')
ax.plot(boundary_line[:, 0], boundary_line[:, 1], 'k-', lw=5)
#fig.savefig('video_kymograph_example_with_wound-line_black.svg', bbox_inches='tight')
plt.show()

"""
5. Gap Closure Finding
"""
from MOSES.Motion_Analysis.wound_close_sweepline_area_segmentation import wound_sweep_area_segmentation

spixel_size = meantracks_r[1,0,1] - meantracks_r[1,0,0]
wound_closure_frame = wound_sweep_area_segmentation(vidstack, spixel_size, max_frame=50, n_sweeps=50, n_points_keep=1, n_clusters=2, p_als=0.001, to_plot=True)
print('predicted gap closure frame is: %d' %(wound_closure_frame))


"""
6. VCCF computation
"""
from MOSES.Motion_Analysis.mesh_statistics_tools import compute_max_vccf_cells_before_after_gap

(max_vccf_before, vccf_before), (max_vccf_after, vccf_after) = compute_max_vccf_cells_before_after_gap(meantracks_r, meantracks_g, wound_heal_frame=wound_closure_frame, err_frame=5)

print('max velocity cross-correlation before: %.3f' %(max_vccf_before))
print('max velocity cross-correlation after: %.3f' %(max_vccf_after))


"""
7. Spatial Correlation Index
"""
from MOSES.Motion_Analysis.mesh_statistics_tools import compute_spatial_correlation_function

spatial_corr_curve, (spatial_corr_pred, a_value, b_value, r_value) = compute_spatial_correlation_function(meantracks_r, wound_closure_frame, wound_heal_err=5, dist_range=np.arange(1,6,1))

# plot the curve and the fitted curve to y=a*exp(-x/b) to get the (a,b) parameters. 
plt.figure()
plt.title('Fitted Spatial Correlation: a=%.3f, b=%.3f' %(a_value, b_value))
plt.plot(np.arange(1,6,1), spatial_corr_curve, 'ko', label='measured')
plt.plot(np.arange(1,6,1), spatial_corr_pred, 'g-', label='fitted')
plt.xlabel('Distance (Number of Superpixels)')
plt.ylabel('Spatial Correlation')
plt.legend(loc='best')
# plt.savefig('Example spatial correlation index.svg', dpi=300, bbox_inches='tight')
plt.show()

"""
Example Motion map analysis using a subset of the full dataset. The full dataset can be downloaded at Mendeley Datasets with DOI: https://dx.doi.org/10.17632/j8yrmntc7x.1 and https://dx.doi.org/10.17632/vrhtsdhprr.1
"""
# 1. construct a function to detect all subfolders as experiment. 
from MOSES.Utility_Functions.file_io import detect_experiments

# detect experiment folders as subfolders under a top-level directory. (replace rootfolder with the correct folder path)
rootfolder = '../Data/Motion_Map_Videos'
expt_folders = detect_experiments(rootfolder, exclude=['meantracks','optflow'], level1=False)
print(expt_folders) # print the detected folder names. 

# 2. detect individual video files and label in terms of 0% or 5% FBS.
import glob
import numpy as np
# All .tif files in each experiment folder. Sorted per folder because glob returns files in
# arbitrary, filesystem-dependent order; sorting makes the video order (and therefore the
# row order of the results computed below) reproducible.
videofiles = [sorted(glob.glob(os.path.join(rootfolder, expt_folder, '*.tif')))
              for expt_folder in expt_folders]
# One integer label per video = index of its experiment folder,
# e.g. [[0,0,0],[1,1,1]] for two folders of three videos each.
# NOTE(review): label 0 = 0% FBS and label 1 = 5% FBS only if detect_experiments returns
# the folders in that order — confirm.
labels = [[folder_idx] * len(folder_files) for folder_idx, folder_files in enumerate(videofiles)]
# flatten everything into single array
videofiles = np.hstack(videofiles)
labels = np.hstack(labels)

# 3. iterate and compute meantracks and MOSES mesh strain curves.
from MOSES.Utility_Functions.file_io import read_multiimg_PIL
from MOSES.Motion_Analysis.mesh_statistics_tools import construct_MOSES_mesh, compute_MOSES_mesh_strain_curve

# set motion extraction parameters.
n_spixels = 1000
optical_flow_params = dict(pyr_scale=0.5, levels=3, winsize=15, iterations=3, poly_n=5, poly_sigma=1.2, flags=0)

n_videos = len(videofiles)
# One normalised mesh strain curve per video, in videofiles order.
mesh_strain_all = []

# The loop-local names (video, tracks_r/g, video_spixel_size) are deliberately distinct from
# the single-video analysis above (vidstack, meantracks_r/g, spixel_size). Reusing those names
# would overwrite them, and the track-filtering demo further down would then pair the last
# motion-map video's tracks with the first video's image size (n_rows, n_cols).
for ii, videofile in enumerate(videofiles):
    print(ii, videofile)
    video = read_multiimg_PIL(videofile)

    # 1. compute superpixel tracks per colour channel (0 = red, 1 = green)
    _, tracks_r = compute_grayscale_vid_superpixel_tracks(video[:,:,:,0], optical_flow_params, n_spixels)
    _, tracks_g = compute_grayscale_vid_superpixel_tracks(video[:,:,:,1], optical_flow_params, n_spixels)
    video_spixel_size = tracks_r[1,0,1] - tracks_r[1,0,0]
    # 2a. compute MOSES mesh per channel (neighbour lists are not needed here).
    mesh_strain_time_r, _ = construct_MOSES_mesh(tracks_r, dist_thresh=1.2, spixel_size=video_spixel_size)
    mesh_strain_time_g, _ = construct_MOSES_mesh(tracks_g, dist_thresh=1.2, spixel_size=video_spixel_size)
    # 2b. compute the MOSES mesh strain curve for the video: mean of the two channels.
    mesh_strain_r = compute_MOSES_mesh_strain_curve(mesh_strain_time_r, normalise=False)
    mesh_strain_g = compute_MOSES_mesh_strain_curve(mesh_strain_time_g, normalise=False)
    mesh_strain_curve_video = .5*(mesh_strain_r+mesh_strain_g)
    # (optional normalization) divide by the curve's maximum
    normalised_mesh_strain_curve_video = mesh_strain_curve_video / np.max(mesh_strain_curve_video)
    # 3. append the computed mesh strain curves.
    mesh_strain_all.append(normalised_mesh_strain_curve_video)

# stack all the mesh strain curves into one array (one row per video)
mesh_strain_all = np.vstack(mesh_strain_all)


# PCA motion map.
from sklearn.decomposition import PCA

# The 2-component PCA basis is learned from the 5% FBS videos (label 1) only;
# the 0% FBS videos (label 0) are then projected into that same basis.
pca_model = PCA(n_components=2, random_state=0)
pca_5_percent_mesh = pca_model.fit_transform(mesh_strain_all[labels == 1])
pca_0_percent_mesh = pca_model.transform(mesh_strain_all[labels == 0])


# Scatter both groups in the 2D PCA space (5% FBS green, 0% FBS red).
fig, ax = plt.subplots(figsize=(3,3))
for group_points, group_colour, group_name in ((pca_5_percent_mesh, 'g', '5% FBS'),
                                               (pca_0_percent_mesh, 'r', '0% FBS')):
    ax.plot(group_points[:, 0], group_points[:, 1], 'o', ms=10, color=group_colour, label=group_name)
ax.set_xlim([-2,2])
ax.set_ylim([-2,2])
plt.legend(loc='best')
# fig.savefig('example_motion_map.svg', dpi=300, bbox_inches='tight')
plt.show()


"""
Demonstration of track cleaning based on motion information for two cell epithelial sheets
"""
from MOSES.Track_Filtering.filter_meantracks_superpixels import filter_red_green_tracks

meantracks_r_filt, meantracks_g_filt = filter_red_green_tracks(meantracks_r, meantracks_g, img_shape=(n_rows, n_cols), frame2=1)



